import pickle
import csv
import json
import jsonlines
import pandas as pd
import numpy as np
import random
import sys
sys.path.append(".")
from rob_baseline import *

from tqdm import tqdm


# Fix the seed of the standard `random` module.
random.seed(42)
import spacy
# spaCy pipeline 'en_core_web_lg'; get_qa_entities uses it for noun chunks
# and part-of-speech tags.
spnlp = spacy.load('en_core_web_lg')

# Words removed from entity tokens before they are used as lookup keys.
stop_words = set(["a","an","the","how","who","what","which","where","when","is","was","that","there","and","or","any","if","their","your","you"])
# List copy of the same words; passed to create_norm_reverseindex by the
# generate_* functions.
stop_words_more = list(stop_words)

def _content_tokens(phrase):
    """Split ``phrase`` on single spaces, dropping stop words and repeats.

    Tokens keep their first-seen order, so the result does not depend on
    set iteration order (which varies with string hashing between runs).
    """
    return [tok for tok in dict.fromkeys(phrase.split(" ")) if tok not in stop_words]


def get_qa_entities(question, answerlist):
    """Extract lowercase keyword tokens from a question and its answer options.

    Args:
        question: Question text; lowercased before being parsed with ``spnlp``.
        answerlist: Iterable of answer-option strings.

    Returns:
        A pair ``(question_entities, answer_entities)``:

        * ``question_entities`` -- flat list of non-stop-word tokens taken
          from the question's noun chunks, followed by those from its verbs.
        * ``answer_entities`` -- one list per answer option. A single-word
          option yields just that word. Otherwise, for every noun chunk and
          verb, the list holds the space-joined filtered phrase followed by
          its individual tokens.
    """
    qdoc = spnlp(question.lower())
    question_phrases = [chunk.text for chunk in qdoc.noun_chunks]
    question_phrases.extend(tok.text for tok in qdoc if tok.pos_ == "VERB")

    question_entities = []
    for phrase in question_phrases:
        question_entities.extend(_content_tokens(phrase))

    answer_entities = []
    for answer in answerlist:
        # Normalise once: lowercase and strip periods.
        answer = answer.lower().replace(".", "")
        if " " not in answer:
            # Single-word options are kept verbatim (no stop-word filtering).
            answer_entities.append([answer])
            continue

        adoc = spnlp(answer)
        phrases = [chunk.text for chunk in adoc.noun_chunks]
        phrases.extend(tok.text for tok in adoc if tok.pos_ == "VERB")

        option_entities = []
        for phrase in phrases:
            tokens = _content_tokens(phrase)
            # The joined phrase is recorded even when it is empty (all
            # tokens were stop words), matching the historical output.
            option_entities.append(" ".join(tokens))
            option_entities.extend(tokens)
        answer_entities.append(option_entities)

    return question_entities, answer_entities

def get_doc(docs):
    """Return the ``nlp`` result for ``docs``, computing it only on first use.

    Results are memoised in the module-level ``docmap`` mapping, keyed by
    ``docs`` itself.

    NOTE(review): neither ``docmap`` nor ``nlp`` is defined in the visible
    part of this file; presumably both come from ``rob_baseline`` -- confirm.
    """
    if docs not in docmap:
        docmap[docs] = nlp(docs)
    return docmap[docs]

def save_maps(fname, mmap):
    """Pickle ``mmap`` into the file ``fname`` using the highest protocol.

    An existing file at ``fname`` is truncated and overwritten.
    """
    with open(fname, 'wb+') as out_file:
        pickler = pickle.Pickler(out_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(mmap)
        
def load_map(fname):
    """Load and return the object pickled in the file ``fname``.

    The file is opened in a ``with`` block so the handle is closed as soon
    as unpickling finishes, including when unpickling raises.

    NOTE: ``pickle.load`` can execute arbitrary code while deserialising;
    only call this on files from a trusted source.
    """
    with open(fname, 'rb') as handle:
        return pickle.load(handle)

def create_norm_reverseindex(mmap, stop_words_more):
    """Invert ``mmap`` into a token -> list-of-keys index.

    Args:
        mmap: Mapping whose values are iterables of entity strings.
        stop_words_more: Tokens to leave out of the index.

    Returns:
        Dict mapping each lowercase, space-separated token of every entity
        (stop words excluded) to the list of ``mmap`` keys it occurred
        under. A key is appended once per entity containing the token, so
        it can appear several times in one list.
    """
    excluded = set(stop_words_more)
    inverted = {}
    for key, entities in tqdm(mmap.items(), ascii=True):
        for entity in entities:
            # Tokens are de-duplicated per entity before being indexed.
            for token in set(entity.lower().split(" ")) - excluded:
                inverted.setdefault(token, []).append(key)
    return inverted

def dedupe_map(mmap):
    """Return a new dict with duplicates removed from each value of ``mmap``.

    Each value is passed through ``set`` and back to ``list``, so the order
    of the surviving elements is not preserved.
    """
    return {
        key: list(set(values))
        for key, values in tqdm(mmap.items(), ascii=True)
    }


def get_sents_for_ents(ents, rnmap, rvmap):
    """Collect the sentences indexed under any of ``ents``.

    Args:
        ents: Iterable of lookup keys.
        rnmap: First index, mapping a key to a sequence of sentences.
        rvmap: Second index with the same layout.

    Returns:
        Set of sentences found in either index. At most the first 100
        entries per key and per index are taken, and the empty string is
        never included.
    """
    collected = set()
    for entity in ents:
        for index in (rnmap, rvmap):
            hits = index.get(entity)
            if hits is not None:
                collected.update(hits[0:100])
    collected.discard("")
    return collected
            
    
def extract_context(rnmap, rvmap, question, answerlist, limit=None):
    """Pick up to five context sentences for a question and its options.

    Scoring:
        * +3 for every sentence retrieved for the question's entities;
        * for each answer option, +5 for each sentence shared with the
          question's sentences, or, when nothing is shared, +1 for each
          sentence retrieved for that option alone.

    Args:
        rnmap: First reverse index passed to ``get_sents_for_ents``.
        rvmap: Second reverse index passed to ``get_sents_for_ents``.
        question: Question text.
        answerlist: Answer-option strings.
        limit: Maximum number of sentences scored per answer option;
            ``None`` means no cap.

    Returns:
        List of at most five sentences, highest score first.
    """
    qentities, aentities = get_qa_entities(question, answerlist)
    scores = {}

    qsents = get_sents_for_ents(qentities, rnmap, rvmap)
    for sent in qsents:
        scores[sent] = scores.get(sent, 0) + 3

    for ans_opts in aentities:
        opt_sents = get_sents_for_ents(ans_opts, rnmap, rvmap)
        shared = opt_sents.intersection(qsents)
        if shared:
            candidates, bonus = shared, 5
        else:
            candidates, bonus = opt_sents, 1
        # `limit` is deliberately never reassigned: slicing with None keeps
        # every candidate. Deriving a number from the first option used to
        # cap (possibly at 0) all later options as well.
        # NOTE: candidates is a set, so which ones survive a cap is arbitrary.
        for sent in list(candidates)[0:limit]:
            scores[sent] = scores.get(sent, 0) + bonus

    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [sent for sent, _score in ranked[0:5]]

def generate_eval(rev_npmap, rev_vpmap, dtype, inpfile, outfile, limit=None):
    """Copy a JSON-lines dataset, adding retrieved contexts to every row.

    Each row of ``inpfile`` is parsed by the reader registered under
    ``dtype`` in ``INSTANCE_READERS``. Contexts are retrieved with
    ``extract_context`` and stored under the row's ``"contexts"`` key, and
    the row is written to ``outfile``.

    Args:
        rev_npmap: First reverse index forwarded to ``extract_context``.
        rev_vpmap: Second reverse index forwarded to ``extract_context``.
        dtype: Key into ``INSTANCE_READERS``.
        inpfile: Path of the JSON-lines file to read.
        outfile: Path of the JSON-lines file to write (overwritten).
        limit: Per-option cap forwarded to ``extract_context``. Defaults to
            ``None`` (no cap) so callers that omit it no longer raise
            ``TypeError``.
    """
    # Resolve the reader and open the input first, so a bad ``dtype`` or a
    # missing input file does not truncate an existing output file.
    instance_reader = INSTANCE_READERS[dtype]()
    with jsonlines.open(inpfile, "r") as fd:
        with jsonlines.open(outfile, "w") as allfd:
            for row in tqdm(fd, "Converting:"):
                # Only the question and choices are needed; the reader's own
                # context, label and merged field are ignored.
                _, question, _, choices, _ = \
                    instance_reader.fields_to_instance(fields=row)
                row["contexts"] = extract_context(
                    rev_npmap, rev_vpmap, question, choices, limit)
                allfd.write(row)
                
def generate_omcs():
    """Write context-augmented CommonsenseQA dev and train splits.

    Both reverse indexes are built here and handed to ``generate_eval``.
    The second one is built from ``omcs_vpmap`` and passed on as a reverse
    index; previously it was built from the first map and the raw
    ``omcs_vpmap`` was passed instead.
    """
    # NOTE(review): `path` is not defined in the visible part of this file,
    # and both maps are loaded from the same name -- confirm the intended
    # pickle files.
    omcs_npmap = load_map(path)
    omcs_vpmap = load_map(path)

    rev_npmap = create_norm_reverseindex(omcs_npmap, stop_words_more)
    rev_vpmap = create_norm_reverseindex(omcs_vpmap, stop_words_more)
    # limit=None (no per-option cap) is spelled out because generate_eval's
    # `limit` parameter was positional-required and omitted here.
    generate_eval(rev_npmap, rev_vpmap, "commonsenseqa", "../data/commonsenseqa/dev_rand_split.jsonl", "../data/commonsenseqa/dev_rand_split_ctxt.jsonl", limit=None)
    generate_eval(rev_npmap, rev_vpmap, "commonsenseqa", "../data/commonsenseqa/train_rand_split.jsonl", "../data/commonsenseqa/train_rand_split_ctxt.jsonl", limit=None)
    
def generate_obqa():
    """Write context-augmented OBQA dev and train splits.

    The second reverse index is built from ``omcs_vpmap`` and is the one
    passed to ``generate_eval``; previously it was built from the first
    map and the raw ``omcs_vpmap`` was passed instead.

    Returns:
        ``(rev_npmap, rev_vpmap)`` -- the two reverse indexes built here.
    """
    # NOTE(review): `path` is not defined in the visible part of this file,
    # and both maps are loaded from the same name -- confirm the intended
    # pickle files.
    omcs_npmap = load_map(path)
    omcs_vpmap = load_map(path)
    rev_npmap = create_norm_reverseindex(omcs_npmap, stop_words_more)
    rev_vpmap = create_norm_reverseindex(omcs_vpmap, stop_words_more)
    generate_eval(rev_npmap, rev_vpmap, "sci", "../data/obqa/dev.jsonl", "../data/obqa/dev_ctxt_qasc.jsonl", limit=100)
    generate_eval(rev_npmap, rev_vpmap, "sci", "../data/obqa/train.jsonl", "../data/obqa/train_ctxt_qasc.jsonl", limit=100)
    return rev_npmap, rev_vpmap
    
    
def generate_arc(rev_npmap, rev_vpmap):
    """Write context-augmented ARC Challenge and Easy dev/train splits.

    All four splits now receive the same pair of reverse indexes.
    Previously the train splits were given the raw ``omcs_vpmap`` while the
    dev splits got a reverse index, and that reverse index was built from
    the first map rather than from ``omcs_vpmap``.

    Args:
        rev_npmap: Currently ignored; rebuilt below.
        rev_vpmap: Currently ignored; rebuilt below.
    """
    # NOTE(review): the arguments are overwritten by freshly built indexes,
    # as in the original code. If the maps loaded here are the same ones the
    # caller used, the reload is redundant -- confirm before removing it.
    # NOTE(review): `path` is not defined in the visible part of this file.
    omcs_npmap = load_map(path)
    omcs_vpmap = load_map(path)
    rev_npmap = create_norm_reverseindex(omcs_npmap, stop_words_more)
    rev_vpmap = create_norm_reverseindex(omcs_vpmap, stop_words_more)
    generate_eval(rev_npmap, rev_vpmap, "sci", "../data/arc/ARC-Challenge-Dev.jsonl", "../data/arc/chall_dev_ctxt.jsonl", limit=100)
    generate_eval(rev_npmap, rev_vpmap, "sci", "../data/arc/ARC-Challenge-Train.jsonl", "../data/arc/chall_train_ctxt.jsonl", limit=100)
    generate_eval(rev_npmap, rev_vpmap, "sci", "../data/arc/ARC-Easy-Dev.jsonl", "../data/arc/easy_dev_ctxt.jsonl", limit=100)
    generate_eval(rev_npmap, rev_vpmap, "sci", "../data/arc/ARC-Easy-Train.jsonl", "../data/arc/easy_train_ctxt.jsonl", limit=100)
    
    
def generate_qasc(rev_npmap, rev_vpmap):
    """Write context-augmented QASC dev and train splits.

    Both splits now receive the same pair of reverse indexes. Previously
    the train split was given the raw ``omcs_vpmap`` while the dev split got
    a reverse index, and that reverse index was built from the first map
    rather than from ``omcs_vpmap``.

    Args:
        rev_npmap: Currently ignored; rebuilt below.
        rev_vpmap: Currently ignored; rebuilt below.
    """
    # NOTE(review): the arguments are overwritten by freshly built indexes,
    # as in the original code. If the maps loaded here are the same ones the
    # caller used, the reload is redundant -- confirm before removing it.
    # NOTE(review): `path` is not defined in the visible part of this file.
    omcs_npmap = load_map(path)
    omcs_vpmap = load_map(path)
    rev_npmap = create_norm_reverseindex(omcs_npmap, stop_words_more)
    rev_vpmap = create_norm_reverseindex(omcs_vpmap, stop_words_more)
    generate_eval(rev_npmap, rev_vpmap, "sci", "../data/qasc/dev.jsonl", "../data/qasc/dev_ctxt_qasc.jsonl", limit=100)
    generate_eval(rev_npmap, rev_vpmap, "sci", "../data/qasc/train.jsonl", "../data/qasc/train_ctxt_qasc.jsonl", limit=100)
    
# First command-line argument; raises IndexError when the script is run
# without one. It is not read again in the visible part of this file.
inp = sys.argv[1]

# CommonsenseQA generation is currently disabled.
# generate_omcs()
# Build the OBQA files, then pass the returned reverse indexes on to the
# ARC and QASC generators.
rev_npmap,rev_vpmap=generate_obqa()
generate_arc(rev_npmap,rev_vpmap)
generate_qasc(rev_npmap,rev_vpmap)


